{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "tflite_text_classification.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "dweksnWs1eZb"
      ]
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VsBo9D7GJ8G"
      },
      "source": [
        "```\n",
        "Copyright 2021 The IREE Authors\n",
        "\n",
        "Licensed under the Apache License v2.0 with LLVM Exceptions.\n",
        "See https://llvm.org/LICENSE.txt for license information.\n",
        "SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Mjqf_WmSkAs"
      },
      "source": [
        "# TFLite text classification sample with IREE\n",
        "\n",
        "This notebook demonstrates how to download, compile, and run a TFLite model with IREE. It uses the pretrained [text classification](https://www.tensorflow.org/lite/examples/text_classification/overview) model and runs it with both TFLite and IREE so that the results can be compared. The model predicts whether a sentence's sentiment is positive or negative, and was trained on a dataset of IMDB movie reviews.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W71slxNjS3SB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T8WMt2sPS4ft"
      },
      "source": [
        "%%capture\n",
        "!python -m pip install iree-compiler-snapshot iree-runtime-snapshot iree-tools-tflite-snapshot -f https://github.com/google/iree/releases/latest\n",
        "!pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "code",
        "id": "7L4_gnRVVBdi"
      },
      "source": [
        "import numpy as np\n",
        "import urllib.request\n",
        "import pathlib\n",
        "import tempfile\n",
        "import re\n",
        "import tflite_runtime.interpreter as tflite\n",
        "\n",
        "from iree import runtime as iree_rt\n",
        "from iree.compiler import compile_str\n",
        "from iree.tools import tflite as iree_tflite\n",
        "\n",
        "ARTIFACTS_DIR = pathlib.Path(tempfile.gettempdir(), \"iree\", \"colab_artifacts\")\n",
        "ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dweksnWs1eZb"
      },
      "source": [
        "### Load the TFLite model\n",
        "\n",
        "1.   Download files for the pretrained model\n",
        "2.   Extract model metadata used for input pre-processing and output post-processing\n",
        "3.   Define helper functions for pre- and post-processing\n",
        "\n",
        "These steps will differ from model to model.  Consult the model source or reference documentation for details.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cUTaotkV7taP",
        "outputId": "778f3dbb-e540-4840-901e-1ede5a8b0cf7"
      },
      "source": [
        "#@title Download pretrained text classification model\n",
        "MODEL_URL = \"https://storage.googleapis.com/download.tensorflow.org/models/tflite/text_classification/text_classification_v2.tflite\"\n",
        "urllib.request.urlretrieve(MODEL_URL, ARTIFACTS_DIR.joinpath(\"text_classification.tflite\"))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(PosixPath('/tmp/iree/colab_artifacts/text_classification.tflite'),\n",
              " <http.client.HTTPMessage at 0x7f86b3d0add0>)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ADu-gSDnIm2B",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "c563db87-ad8c-4f72-8b22-afd5c9a1717f"
      },
      "source": [
        "#@title Extract model vocab and label metadata\n",
        "# The .tflite file also unpacks as a zip archive holding labels.txt and vocab.txt.\n",
        "!unzip -d {ARTIFACTS_DIR} {ARTIFACTS_DIR}/text_classification.tflite\n",
        "\n",
        "# Load the vocab file into a dictionary that maps each known word (and the\n",
        "# special <PAD>, <START>, and <UNKNOWN> tokens) to an integer ID.\n",
        "vocab = {}\n",
        "with open(ARTIFACTS_DIR.joinpath(\"vocab.txt\")) as vocab_file:\n",
        "  for line in vocab_file:\n",
        "    (key, val) = line.split()\n",
        "    vocab[key] = int(val)\n",
        "\n",
        "# Text will be labeled as either 'Positive' or 'Negative'.\n",
        "with open(ARTIFACTS_DIR.joinpath(\"labels.txt\")) as label_file:\n",
        "  labels = label_file.read().splitlines()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Archive:  /tmp/iree/colab_artifacts/text_classification.tflite\n",
            " extracting: /tmp/iree/colab_artifacts/labels.txt  \n",
            " extracting: /tmp/iree/colab_artifacts/vocab.txt  \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Z1WKEZY1JH6E"
      },
      "source": [
        "#@title Input and output processing\n",
        "\n",
        "# Input text is encoded as an integer array of fixed length 256. Each word\n",
        "# of the input sentence is mapped to its integer ID from the vocab dictionary,\n",
        "# and the remaining array slots are filled with the padding ID.\n",
        "\n",
        "SENTENCE_LEN = 256\n",
        "START = \"<START>\"\n",
        "PAD = \"<PAD>\"\n",
        "UNKNOWN = \"<UNKNOWN>\"\n",
        "\n",
        "def tokenize_input(text):\n",
        "  output = np.full([1, SENTENCE_LEN], vocab[PAD], dtype=np.int32)\n",
        "\n",
        "  # Remove capitalization and punctuation (except apostrophes) from each word.\n",
        "  words = [re.sub(r\"[^\\w\\s']\", '', word.lower()) for word in text.split()]\n",
        "\n",
        "  # Prepend <START>, then truncate so the sequence fits in SENTENCE_LEN slots.\n",
        "  output[0][0] = vocab[START]\n",
        "  for index, word in enumerate(words[:SENTENCE_LEN - 1], start=1):\n",
        "    output[0][index] = vocab.get(word, vocab[UNKNOWN])\n",
        "\n",
        "  return output\n",
        "\n",
        "\n",
        "def interpret_output(output):\n",
        "  if output[0] >= output[1]:\n",
        "    label = labels[0]\n",
        "    confidence = output[0]\n",
        "  else:\n",
        "    label = labels[1]\n",
        "    confidence = output[1]\n",
        "\n",
        "  print(\"Label: \" + label + \"\\nConfidence: \" + str(confidence))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xq00JknF0cOz",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "866b81ac-74fb-4133-f229-980bb1a5ebae"
      },
      "source": [
        "#@title Text samples\n",
        "positive_text = \"This is the best movie I've seen in recent years. Strongly recommend it!\"\n",
        "negative_text = \"What a waste of my time.\"\n",
        "\n",
        "print(positive_text)\n",
        "print(tokenize_input(positive_text))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "This is the best movie I've seen in recent years. Strongly recommend it!\n",
            "[[   1   13    8    3  117   19  206  109   10 1134  152 2301  385   11\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
            "     0    0    0    0]]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fvxhY1W7X7Gf"
      },
      "source": [
        "## Run using TFLite\n",
        "\n",
        "Overview:\n",
        "\n",
        "1.   Load the TFLite model into a [TFLite Interpreter](https://www.tensorflow.org/api_docs/python/tf/lite/Interpreter)\n",
        "2.   Allocate tensors and get the input and output tensor details\n",
        "3.   Invoke the TFLite Interpreter to test the text classification function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nv8x81S5eV5g"
      },
      "source": [
        "interpreter = tflite.Interpreter(\n",
        "      model_path=str(ARTIFACTS_DIR.joinpath(\"text_classification.tflite\")))\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "\n",
        "def classify_text_tflite(text):\n",
        "  interpreter.set_tensor(input_details[0]['index'], tokenize_input(text))\n",
        "  interpreter.invoke()\n",
        "  output_data = interpreter.get_tensor(output_details[0]['index'])\n",
        "  interpret_output(output_data[0])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kpxfU88ckFxI",
        "outputId": "4ab615c4-8492-464a-c5b2-d9107d6b22c7"
      },
      "source": [
        "print(\"Invoking text classification with TFLite\\n\")\n",
        "print(positive_text)\n",
        "classify_text_tflite(positive_text)\n",
        "print()\n",
        "print(negative_text)\n",
        "classify_text_tflite(negative_text)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Invoking text classification with TFLite\n",
            "\n",
            "This is the best movie I've seen in recent years. Strongly recommend it!\n",
            "Label: Positive\n",
            "Confidence: 0.8997293\n",
            "\n",
            "What a waste of my time.\n",
            "Label: Negative\n",
            "Confidence: 0.6275043\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_FgwUBtm7Y5n"
      },
      "source": [
        "## Run using IREE\n",
        "\n",
        "Overview:\n",
        "\n",
        "1.   Import the TFLite model to TOSA MLIR\n",
        "2.   Compile the TOSA MLIR into an IREE flatbuffer and load it as a VM module\n",
        "3.   Run the VM module through IREE's runtime to test the text classification function\n",
        "\n",
        "Both runtimes should generate the same output, up to small floating-point differences.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZifuTzgEXLJn"
      },
      "source": [
        "# Convert TFLite model to TOSA MLIR with IREE's import tool.\n",
        "IREE_TFLITE_TOOL = iree_tflite.get_tool('iree-import-tflite')\n",
        "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification.mlir\n",
        "\n",
        "with open(ARTIFACTS_DIR.joinpath(\"text_classification.mlir\")) as mlir_file:\n",
        "  tosa_mlir = mlir_file.read()\n",
        "\n",
        "# Manually add the \"iree.module.export\" attribute to the entry function until\n",
        "# the attribute is no longer required: https://github.com/google/iree/issues/3968\n",
        "modified_tosa_mlir = tosa_mlir.replace('outputs = \"Identity\"}', 'outputs = \"Identity\"}, iree.module.export')\n",
        "\n",
        "# The imported .mlir file is saved in ARTIFACTS_DIR, so it can also be used\n",
        "# outside of Python, for example with IREE's native tools or in apps."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-C-5X4C0D0La",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4e7a41f5-a886-454f-fad2-db7e210cf10e"
      },
      "source": [
        "# The model contains very large constants, so re-import it with large constants elided to get a version that is practical to print.\n",
        "!{IREE_TFLITE_TOOL} {ARTIFACTS_DIR}/text_classification.tflite -o={ARTIFACTS_DIR}/text_classification_truncated.mlir -mlir-elide-elementsattrs-if-larger=50\n",
        "\n",
        "with open(ARTIFACTS_DIR.joinpath(\"text_classification_truncated.mlir\")) as truncated_mlir_file:\n",
        "  truncated_tosa_mlir = truncated_mlir_file.read()\n",
        "  print(truncated_tosa_mlir)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "module  {\n",
            "  func @main(%arg0: tensor<1x256xi32>) -> tensor<1x2xf32> attributes {tf.entry_function = {inputs = \"input_5\", outputs = \"Identity\"}} {\n",
            "    %0 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<10003x16xf32>} : () -> tensor<10003x16xf32>\n",
            "    %1 = \"tosa.const\"() {value = opaque<\"_\", \"0xDEADBEEF\"> : tensor<16x16xf32>} : () -> tensor<16x16xf32>\n",
            "    %2 = \"tosa.const\"() {value = dense<[-0.00698487554, 0.0294856895, 0.0699710473, 0.130019352, -0.0490558445, 0.0987673401, 0.0744077861, 0.0948959812, -0.010937131, 0.0931261852, 0.0711835548, -0.0385615043, 9.962780e-03, 0.00283221388, 0.112116851, 0.0134318024]> : tensor<16xf32>} : () -> tensor<16xf32>\n",
            "    %3 = \"tosa.const\"() {value = dense<[[0.091361463, -1.23269629, 1.33242488, 0.92142266, -0.445623249, 0.849273681, -1.27237022, 1.28574562, 0.436188251, -0.963210225, 0.745473146, -0.255745709, -1.4491415, -1.4687326, 0.900665163, -1.36293614], [-0.0968776941, 0.771379471, -1.36363328, -1.1110599, -0.304591209, -1.05579722, 0.795746565, -1.3122592, 0.352218777, 1.04682362, -1.18796027, -0.0409261398, 1.05883229, 1.48620188, -1.13325548, 1.03072512]]> : tensor<2x16xf32>} : () -> tensor<2x16xf32>\n",
            "    %4 = \"tosa.const\"() {value = dense<[0.043447677, -0.0434476472]> : tensor<2xf32>} : () -> tensor<2xf32>\n",
            "    %5 = \"tosa.const\"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>\n",
            "    %6 = \"tosa.const\"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>\n",
            "    %7 = \"tosa.const\"() {value = dense<3.906250e-03> : tensor<f32>} : () -> tensor<f32>\n",
            "    %8 = \"tosa.transpose\"(%0, %5) : (tensor<10003x16xf32>, tensor<2xi32>) -> tensor<10003x16xf32>\n",
            "    %9 = \"tosa.reshape\"(%8) {new_shape = [1, 10003, 16]} : (tensor<10003x16xf32>) -> tensor<1x10003x16xf32>\n",
            "    %10 = \"tosa.reshape\"(%arg0) {new_shape = [1, 256]} : (tensor<1x256xi32>) -> tensor<1x256xi32>\n",
            "    %11 = \"tosa.gather\"(%9, %10) : (tensor<1x10003x16xf32>, tensor<1x256xi32>) -> tensor<1x256x16xf32>\n",
            "    %12 = \"tosa.reshape\"(%11) {new_shape = [1, 256, 16]} : (tensor<1x256x16xf32>) -> tensor<1x256x16xf32>\n",
            "    %13 = \"tosa.transpose\"(%12, %6) : (tensor<1x256x16xf32>, tensor<3xi32>) -> tensor<1x256x16xf32>\n",
            "    %14 = \"tosa.reduce_sum\"(%13) {axis = 1 : i64} : (tensor<1x256x16xf32>) -> tensor<1x1x16xf32>\n",
            "    %15 = \"tosa.reshape\"(%14) {new_shape = [1, 16]} : (tensor<1x1x16xf32>) -> tensor<1x16xf32>\n",
            "    %16 = \"tosa.reshape\"(%7) {new_shape = [1, 1]} : (tensor<f32>) -> tensor<1x1xf32>\n",
            "    %17 = \"tosa.mul\"(%15, %16) {shift = 0 : i32} : (tensor<1x16xf32>, tensor<1x1xf32>) -> tensor<1x16xf32>\n",
            "    %18 = \"tosa.fully_connected\"(%17, %1, %2) : (tensor<1x16xf32>, tensor<16x16xf32>, tensor<16xf32>) -> tensor<1x16xf32>\n",
            "    %19 = \"tosa.clamp\"(%18) {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x16xf32>) -> tensor<1x16xf32>\n",
            "    %20 = \"tosa.fully_connected\"(%19, %3, %4) : (tensor<1x16xf32>, tensor<2x16xf32>, tensor<2xf32>) -> tensor<1x2xf32>\n",
            "    %21 = \"tosa.exp\"(%20) : (tensor<1x2xf32>) -> tensor<1x2xf32>\n",
            "    %22 = \"tosa.reduce_sum\"(%21) {axis = 1 : i64} : (tensor<1x2xf32>) -> tensor<1x1xf32>\n",
            "    %23 = \"tosa.reciprocal\"(%22) : (tensor<1x1xf32>) -> tensor<1x1xf32>\n",
            "    %24 = \"tosa.mul\"(%21, %23) {shift = 0 : i32} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x2xf32>\n",
            "    return %24 : tensor<1x2xf32>\n",
            "  }\n",
            "}\n",
            "\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "M3gXX2AF7aS9",
        "outputId": "b60627f7-a761-444a-b41a-058cea01fbfe"
      },
      "source": [
        "# Compile the TOSA MLIR into a VM module.\n",
        "compiled_flatbuffer = compile_str(modified_tosa_mlir, target_backends=[\"vmvx\"])\n",
        "vm_module = iree_rt.VmModule.from_flatbuffer(compiled_flatbuffer)\n",
        "\n",
        "# Register the module with a runtime context.\n",
        "config = iree_rt.Config(\"vmvx\")\n",
        "ctx = iree_rt.SystemContext(config=config)\n",
        "ctx.add_vm_module(vm_module)\n",
        "invoke_text_classification = ctx.modules.module[\"main\"]\n",
        "\n",
        "def classify_text_iree(text):\n",
        "  result = invoke_text_classification(tokenize_input(text))\n",
        "  interpret_output(result[0])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Created IREE driver vmvx: <iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n",
            "SystemContext driver=<iree.runtime.binding.HalDriver object at 0x7f86b37fa870>\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jvv9zhMwAWgZ",
        "outputId": "2fc7b07c-2612-45a1-e333-3b9a10aaca88"
      },
      "source": [
        "print(\"Invoking text classification with IREE\\n\")\n",
        "print(positive_text)\n",
        "classify_text_iree(positive_text)\n",
        "print()\n",
        "print(negative_text)\n",
        "classify_text_iree(negative_text)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Invoking text classification with IREE\n",
            "\n",
            "This is the best movie I've seen in recent years. Strongly recommend it!\n",
            "Label: Positive\n",
            "Confidence: 0.8997293\n",
            "\n",
            "What a waste of my time.\n",
            "Label: Negative\n",
            "Confidence: 0.62750435\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}
